#include "easy_grpc/easy_grpc.h"

#include <array>
#include <string>
#include <utility>
#include <vector>

#include "generated/test.egrpc.pb.h"

#include "gtest/gtest.h"

namespace rpc = easy_grpc;

namespace {
// Test implementation of tests::TestService. TestMethod answers with the
// request name combined with the value of the client-supplied "expected"
// header.
class Test_sync_impl : public tests::TestService {
 public:
  // Builds a reply named "<req.name()>_replied_with_<expected header>".
  // The promise is fulfilled before its future is returned, so the returned
  // future already holds the reply.
  ::rpc::Future<::tests::TestReply> TestMethod(::tests::TestRequest req, ::easy_grpc::Context ctx) override {
    const std::string expected_value(ctx.get_client_header("expected"));

    ::tests::TestReply reply;
    reply.set_name(req.name() + "_replied_with_" + expected_value);

    ::rpc::Promise<::tests::TestReply> reply_promise;
    auto reply_future = reply_promise.get_future();
    reply_promise.set_value(reply);

    return reply_future;
  }
};
}  // namespace

// Calls TestMethod without sending the "expected" client header and checks
// that the call surfaces an rpc::Rpc_error on the client side.
TEST(test_headers, missing_header) {
  rpc::Environment grpc_env;

  std::array<rpc::Completion_queue, 1> server_queues;
  rpc::Completion_queue client_queue;

  Test_sync_impl sync_srv;

  int server_port = 0;
  rpc::server::Server server = std::move(rpc::server::Config()
                                             .add_default_listening_queues({server_queues.begin(), server_queues.end()})
                                             .add_service(sync_srv)
                                             .add_listening_port("127.0.0.1:0", {}, &server_port));

  // A zero port means no listening port was reported; stop here rather than
  // issuing an RPC against 127.0.0.1:0.
  ASSERT_NE(0, server_port);

  {
    rpc::client::Unsecure_channel channel(std::string("127.0.0.1:") + std::to_string(server_port), &client_queue);
    tests::TestService::Stub stub(&channel);

    ::tests::TestRequest req;
    req.set_name("dude");

    // The "expected" header is deliberately omitted, so the call is expected
    // to fail with an Rpc_error.
    EXPECT_THROW(stub.TestMethod(req).get().name(), rpc::Rpc_error);
  }
}

// Sends the "expected" client header with value "abc" and checks that the
// server echoes it back in the reply name.
TEST(test_headers, unary) {
  rpc::Environment grpc_env;

  std::array<rpc::Completion_queue, 1> server_queues;
  rpc::Completion_queue client_queue;

  Test_sync_impl sync_srv;

  int server_port = 0;
  rpc::server::Server server = std::move(rpc::server::Config()
                                             .add_default_listening_queues({server_queues.begin(), server_queues.end()})
                                             .add_service(sync_srv)
                                             .add_listening_port("127.0.0.1:0", {}, &server_port));

  // A zero port means no listening port was reported; stop here rather than
  // issuing an RPC against 127.0.0.1:0.
  ASSERT_NE(0, server_port);

  {
    rpc::client::Unsecure_channel channel(std::string("127.0.0.1:") + std::to_string(server_port), &client_queue);
    tests::TestService::Stub stub(&channel);

    ::tests::TestRequest req;
    req.set_name("dude");

    rpc::client::Call_options options;
    std::vector<grpc_metadata> headers;

    // Value-initialize so that any grpc_metadata fields besides key/value
    // (gRPC core defines extra internal fields) are zeroed, not indeterminate.
    grpc_metadata trial_header{};
    trial_header.key = grpc_slice_from_static_string("expected");
    trial_header.value = grpc_slice_from_static_string("abc");
    headers.push_back(trial_header);
    // options.headers holds a pointer to `headers`, so the vector must stay
    // alive while the call below uses it.
    options.headers = &headers;

    EXPECT_EQ(stub.TestMethod(req, options).get().name(), "dude_replied_with_abc");

    grpc_slice_unref(trial_header.key);
    grpc_slice_unref(trial_header.value);
  }
}
